from datetime import datetime, timedelta, time
from collections.abc import Callable
import os
import json
from pathlib import Path

from pandas import DataFrame
from filelock import FileLock, Timeout

from vnpy.trader.setting import SETTINGS
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData, HistoryRequest
from vnpy.trader.utility import ZoneInfo, get_file_path
from vnpy.trader.datafeed import BaseDatafeed

from .xt_config import VIP_ADDRESS_LIST, LISTEN_PORT

# CRITICAL: detect "client" mode BEFORE xtquant is imported (the import follows
# below). SETTINGS cannot be relied on here because main.py only configures
# SETTINGS after this module has been imported, so config.json is read
# directly from disk instead.
#
# Result: `_is_client_mode` is True only when a config.json was found and its
# "datafeed" -> "username" entry equals "client"; in that case the
# XT_LOCAL_MODE and XT_DISABLE_REMOTE environment variables are set to '1'.
# NOTE(review): whether xtquant actually reads these two environment variables
# is not visible from this file — confirm against the xtquant version in use.
_is_client_mode = False
try:
    # Candidate config.json locations, checked in order; the first one that
    # exists is loaded and the search stops there.
    possible_config_paths = [
        Path.cwd() / "config.json",  # current working directory
        Path.cwd() / "liugejiao_qt" / "config.json",  # project directory under cwd
        Path(__file__).parent.parent.parent / "liugejiao_qt" / "config.json",  # relative to the vnpy_xt package location
    ]
    
    config_data = None
    for config_path in possible_config_paths:
        if config_path.exists():
            # A file that exists but fails to parse raises out of the loop,
            # so later candidates are NOT tried (the except below handles it).
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
                print(f"🔧 [xt_datafeed] 找到配置文件: {config_path}")
                break
    
    # No file found, or an empty JSON object, leaves token mode in effect.
    if config_data:
        datafeed_username = config_data.get("datafeed", {}).get("username", "")
        _is_client_mode = datafeed_username == "client"
        if _is_client_mode:
            # Use hardcoded values since SETTINGS is not available yet
            os.environ["XT_LOCAL_MODE"] = '1'
            os.environ["XT_DISABLE_REMOTE"] = '1'
            print(f"🔧 [xt_datafeed] Client模式环境配置完成 (username='{datafeed_username}')")
        else:
            print(f"🔧 [xt_datafeed] Token模式 (username='{datafeed_username}')")
except Exception as e:
    # Any failure (I/O error, invalid JSON, unexpected structure) falls back
    # to token mode instead of aborting the module import.
    print(f"⚠️  [xt_datafeed] 无法读取配置文件: {e}")
    _is_client_mode = False

# 现在导入xtquant
from xtquant import xtdata, xtdatacenter as xtdc

# CRITICAL: in client mode, block every route to remote xtdc servers right
# after importing xtquant, before any xtdata operation runs.
# The two safeguards are installed independently, so that a failure in the
# first one cannot prevent the second one from being applied.
if _is_client_mode:
    # Safeguard 1: empty the list of allowed (optimised) remote addresses.
    try:
        xtdc.set_allow_optmize_address([])
        print("✅ [xt_datafeed] 已清空远程服务器列表（Client模式）")
    except Exception as e:
        print(f"⚠️  [xt_datafeed] 清空远程服务器列表时出错: {e}")

    # Safeguard 2: monkey-patch xtdc.init so any later call becomes a no-op
    # instead of initialising a remote connection. The original function is
    # kept in `_original_xtdc_init` so it remains reachable if needed.
    try:
        _original_xtdc_init = xtdc.init

        def _client_mode_init(*args, **kwargs):
            """Stand-in for xtdc.init that refuses remote initialisation."""
            print("⚠️  [xt_datafeed] 拦截xtdc.init()调用（Client模式下禁止远程初始化）")
            return None

        xtdc.init = _client_mode_init
        print("✅ [xt_datafeed] 已拦截xtdc.init()方法")
    except Exception as e:
        print(f"⚠️  [xt_datafeed] 拦截xtdc.init()时出错: {e}")

# Disable xtdata's "hello" message.
xtdata.enable_hello = False


# vn.py Interval -> xtquant period string passed to xtdata download/query calls.
INTERVAL_VT2XT: dict[Interval, str] = {
    Interval.MINUTE: "1m",
    Interval.DAILY: "1d",
    Interval.TICK: "tick"
}

# Offset subtracted from xtquant bar timestamps: xtquant stamps a bar at its
# end, vn.py at its start, so 1-minute bars are shifted back by one minute.
INTERVAL_ADJUSTMENT_MAP: dict[Interval, timedelta] = {
    Interval.MINUTE: timedelta(minutes=1),
    Interval.DAILY: timedelta()         # daily bars need no adjustment
}

# vn.py Exchange -> xtquant market suffix, used to build "<symbol>.<suffix>".
EXCHANGE_VT2XT: dict[Exchange, str] = {
    Exchange.SSE: "SH",
    Exchange.SZSE: "SZ",
    Exchange.BSE: "BJ",
    Exchange.SHFE: "SF",
    Exchange.CFFEX: "IF",
    Exchange.INE: "INE",
    Exchange.DCE: "DF",
    Exchange.CZCE: "ZF",
    Exchange.GFEX: "GF",
}

# Timezone attached to every datetime this module produces.
CHINA_TZ = ZoneInfo("Asia/Shanghai")


class XtDatafeed(BaseDatafeed):
    """Xuntou (xtquant) data service interface.

    The connection mode is selected by ``SETTINGS["datafeed.username"]``:

    * ``"client"``: use the locally running QMT client; this class does not
      start any xtdc service.
    * anything else (token mode): ``SETTINGS["datafeed.password"]`` is used
      as the xtdc token and an xtdc service is started. A file lock guards
      this, so only the process that acquires the lock starts xtdc.
    """

    lock_filename = "xt_lock"
    lock_filepath = get_file_path(lock_filename)

    def __init__(self) -> None:
        """Read credentials from SETTINGS; connecting is deferred to init()."""
        self.username: str = SETTINGS["datafeed.username"]
        self.password: str = SETTINGS["datafeed.password"]
        self.inited: bool = False

        # FileLock created by get_lock(); None until get_lock() is called.
        self.lock: FileLock | None = None

    def init(self, output: Callable = print) -> bool:
        """Connect to the data service.

        Args:
            output: callable used to report progress and errors.

        Returns:
            True on success or if already initialised. False if any exception
            was raised while connecting; troubleshooting hints are then sent
            to ``output``.
        """
        if self.inited:
            return True

        try:
            if self.username == "client":
                # Client mode: rely on the local QMT client. init_xtdc() is
                # intentionally skipped so no remote connection is made.
                output("🔗 使用Client模式 - 连接本地QMT客户端")
            else:
                # Token mode: start xtdc with the token and connect remotely.
                output(f"🔗 使用Token模式 - 连接远程服务器 (username={self.username})")
                self.init_xtdc()

            # Query a well-known instrument to confirm the connection works.
            test_result = xtdata.get_instrument_detail("000001.SZ")
            if test_result:
                output("✅ 迅投数据服务初始化成功")
            else:
                # An empty answer is reported but still counts as success.
                output("⚠️  查询合约返回空数据，但未报错")

        except Exception as ex:
            output(f"❌ 迅投研数据服务初始化失败，发生异常：{ex}")
            self._output_init_help(output)
            return False

        self.inited = True
        return True

    def _output_init_help(self, output: Callable) -> None:
        """Emit troubleshooting hints matching the configured connection mode."""
        output("")
        if self.username == "client":
            output("💡 Client模式需要:")
            output("   1. 确保 XtMiniQmt.exe 正在运行")
            output("   2. 确保已登录账户")
            output("   3. 等待QMT客户端完全启动 (10-30秒)")
            output("   4. 检查config.json中datafeed配置:")
            output("      - username: \"client\"")
            output("      - password: \"\" (留空)")
        else:
            output("💡 Token模式需要:")
            output("   1. 确保 token 有效")
            output("   2. 检查网络连接")
            output(f"   3. username={self.username}")

    def get_lock(self, timeout: float = 1) -> bool:
        """Try to acquire the xtdc file lock so only one process starts xtdc.

        Args:
            timeout: seconds to wait for the lock (defaults to 1; not read
                from SETTINGS).

        Returns:
            True if the lock was acquired, False if the wait timed out.
        """
        self.lock = FileLock(self.lock_filepath)

        try:
            self.lock.acquire(timeout=timeout)
        except Timeout:
            return False
        return True

    def init_xtdc(self) -> None:
        """Start the xtdc service process (token mode).

        Returns without doing anything if the file lock cannot be acquired.
        Presumably another process already runs xtdc in that case.
        """
        if not self.get_lock():
            return

        # Authenticate with the token stored as the datafeed password.
        xtdc.set_token(self.password)

        # Restrict the connection pool to the configured VIP addresses.
        xtdc.set_allow_optmize_address(VIP_ADDRESS_LIST)

        # Use real futures night-session times.
        xtdc.set_future_realtime_mode(True)

        # Initialise without starting the default 58609 port listener...
        xtdc.init(False)

        # ...and listen on the configured port instead.
        xtdc.listen(port=LISTEN_PORT)

    @staticmethod
    def _to_china_datetime(timestamp_ms: float) -> datetime:
        """Convert an xtquant epoch-millisecond timestamp to an aware China-time datetime.

        The timezone is passed to fromtimestamp() directly. Converting via the
        host's local zone and then relabelling it as CHINA_TZ is only correct on
        machines that run in Asia/Shanghai.
        """
        return datetime.fromtimestamp(timestamp_ms / 1000, CHINA_TZ)

    @staticmethod
    def _is_auction_time(exchange: Exchange, dt: datetime) -> bool:
        """Return True if a minute bar starting at ``dt`` is a pre-open call-auction bar."""
        bar_time: time = dt.time()
        if exchange in (Exchange.SSE, Exchange.SZSE, Exchange.BSE, Exchange.CFFEX):
            return bar_time == time(hour=9, minute=29)
        if exchange in (Exchange.SHFE, Exchange.INE, Exchange.DCE, Exchange.CZCE, Exchange.GFEX):
            return bar_time in (time(hour=8, minute=59), time(hour=20, minute=59))
        return False

    def query_bar_history(self, req: HistoryRequest, output: Callable = print) -> list[BarData] | None:
        """Query bar (K-line) history.

        Bars are stamped at their start time in CHINA_TZ. For daily bars, today's
        bar is dropped before 15:00 China time because it is still incomplete.
        For minute bars, the pre-open call-auction bar is merged into the next
        bar instead of being returned.

        Returns:
            The list of bars. It is empty if initialisation fails or no data
            comes back.
        """
        history: list[BarData] = []

        if not self.inited and not self.init(output):
            return history

        df: DataFrame = get_history_df(req, output)
        if df.empty:
            return history

        adjustment: timedelta = INTERVAL_ADJUSTMENT_MAP.get(req.interval, timedelta())

        # Evaluated once, in China time, so every row is judged against the
        # same instant and the same clock as the bar timestamps.
        now: datetime = datetime.now(CHINA_TZ)
        today = now.date()
        today_incomplete: bool = now.time() < time(hour=15)

        auction_bar: BarData | None = None

        for tp in df.itertuples():
            # Convert the xtquant timestamp (bar end) to a vn.py one (bar start).
            dt: datetime = self._to_china_datetime(tp.time) - adjustment

            if req.interval == Interval.DAILY:
                # Skip today's not-yet-finished daily bar.
                if today_incomplete and dt.date() == today:
                    continue
            elif self._is_auction_time(req.exchange, dt):
                # Hold the call-auction bar so it can be merged into the next bar.
                auction_bar = BarData(
                    symbol=req.symbol,
                    exchange=req.exchange,
                    datetime=dt,
                    open_price=float(tp.open),
                    volume=float(tp.volume),
                    turnover=float(tp.amount),
                    gateway_name="XT"
                )
                continue

            bar: BarData = BarData(
                symbol=req.symbol,
                exchange=req.exchange,
                datetime=dt,
                interval=req.interval,
                volume=float(tp.volume),
                turnover=float(tp.amount),
                open_interest=float(tp.openInterest),
                open_price=float(tp.open),
                high_price=float(tp.high),
                low_price=float(tp.low),
                close_price=float(tp.close),
                gateway_name="XT"
            )

            # Merge the pending call-auction data (only if it traded).
            if auction_bar and auction_bar.volume:
                bar.open_price = auction_bar.open_price
                bar.high_price = max(bar.high_price, auction_bar.open_price)
                bar.low_price = min(bar.low_price, auction_bar.open_price)
                bar.volume += auction_bar.volume
                bar.turnover += auction_bar.turnover
                auction_bar = None

            history.append(bar)

        return history

    def query_tick_history(self, req: HistoryRequest, output: Callable = print) -> list[TickData] | None:
        """Query tick history.

        Ticks are stamped in CHINA_TZ. Depth levels 2-5 are filled only when the
        level-2 bid price is non-zero.

        Returns:
            The list of ticks. It is empty if initialisation fails or no data
            comes back.
        """
        history: list[TickData] = []

        if not self.inited and not self.init(output):
            return history

        df: DataFrame = get_history_df(req, output)
        if df.empty:
            return history

        for tp in df.itertuples():
            dt: datetime = self._to_china_datetime(tp.time)

            bidPrice: list[float] = tp.bidPrice
            askPrice: list[float] = tp.askPrice
            bidVol: list[float] = tp.bidVol
            askVol: list[float] = tp.askVol

            tick: TickData = TickData(
                symbol=req.symbol,
                exchange=req.exchange,
                datetime=dt,
                volume=float(tp.volume),
                turnover=float(tp.amount),
                open_interest=float(tp.openInt),
                open_price=float(tp.open),
                high_price=float(tp.high),
                low_price=float(tp.low),
                last_price=float(tp.lastPrice),
                pre_close=float(tp.lastClose),
                bid_price_1=float(bidPrice[0]),
                ask_price_1=float(askPrice[0]),
                bid_volume_1=float(bidVol[0]),
                ask_volume_1=float(askVol[0]),
                gateway_name="XT",
            )

            # A zero level-2 bid is treated as "no depth beyond level 1".
            bid_price_2: float = float(bidPrice[1])
            if bid_price_2:
                tick.bid_price_2 = bid_price_2
                tick.bid_price_3 = float(bidPrice[2])
                tick.bid_price_4 = float(bidPrice[3])
                tick.bid_price_5 = float(bidPrice[4])

                tick.ask_price_2 = float(askPrice[1])
                tick.ask_price_3 = float(askPrice[2])
                tick.ask_price_4 = float(askPrice[3])
                tick.ask_price_5 = float(askPrice[4])

                tick.bid_volume_2 = float(bidVol[1])
                tick.bid_volume_3 = float(bidVol[2])
                tick.bid_volume_4 = float(bidVol[3])
                tick.bid_volume_5 = float(bidVol[4])

                tick.ask_volume_2 = float(askVol[1])
                tick.ask_volume_3 = float(askVol[2])
                tick.ask_volume_4 = float(askVol[3])
                tick.ask_volume_5 = float(askVol[4])

            history.append(tick)

        return history


def get_history_df(req: HistoryRequest, output: Callable = print) -> DataFrame:
    """Download and load history data for ``req`` via xtdata.

    A missing ``req.interval`` is treated as tick data. The end date is
    extended by one day so that night-session data is included.

    Returns:
        The DataFrame xtdata returns for the symbol (equal-ratio forward
        adjusted). An empty DataFrame is returned when the interval or
        exchange is unsupported, or when no data comes back for the symbol.
    """
    symbol: str = req.symbol
    exchange: Exchange = req.exchange
    start_dt: datetime = req.start
    end_dt: datetime = req.end
    interval: Interval = req.interval

    if not interval:
        interval = Interval.TICK

    xt_interval: str | None = INTERVAL_VT2XT.get(interval, None)
    if not xt_interval:
        output(f"迅投研查询历史数据失败：不支持的时间周期{interval.value}")
        return DataFrame()

    # Report unsupported exchanges instead of raising a bare KeyError.
    xt_exchange: str | None = EXCHANGE_VT2XT.get(exchange, None)
    if not xt_exchange:
        output(f"迅投研查询历史数据失败：不支持的交易所{exchange.value}")
        return DataFrame()

    # Extend the end by one day so that night-session data is included.
    end_dt += timedelta(1)

    xt_symbol: str = symbol + "." + xt_exchange
    start: str = start_dt.strftime("%Y%m%d%H%M%S")
    end: str = end_dt.strftime("%Y%m%d%H%M%S")

    # SSE/SZSE codes longer than 6 characters get an extra "O" market suffix
    # (presumably exchange-listed options, e.g. "xxx.SHO").
    if exchange in (Exchange.SSE, Exchange.SZSE) and len(symbol) > 6:
        xt_symbol += "O"

    # Download into the local cache first, then read it back.
    xtdata.download_history_data(xt_symbol, xt_interval, start, end)
    # count=-1: whole range; "front_ratio": equal-ratio forward adjustment
    # (default here); fill_data=False.
    data: dict = xtdata.get_local_data([], [xt_symbol], xt_interval, start, end, -1, "front_ratio", False)

    # Callers test df.empty, so "no data" maps to an empty DataFrame.
    df: DataFrame | None = data.get(xt_symbol) if data else None
    if df is None:
        return DataFrame()
    return df
